import socket
import sys
import os
import shutil
import torch
import wandb
from runners.scratchRunner import scratchRunner
from runners.pretrainedRunner import pretrainedRunner
from strategies import scratchStrategies
# Default wandb configuration. Every hyperparameter starts out unset (None);
# concrete values are supplied by a wandb sweep or by the --debug override.
defaults = {
    # --- System ---
    'run_id': None,
    'computer': socket.gethostname(),
    'collect_class_statistics': False,
    # --- Setup ---
    'dataset': None,
    'arch': None,
    'n_epochs': None,
    'batch_size': None,
    # --- Efficiency ---
    'use_amp': True,
    # --- Optimizer ---
    'optimizer': None,
    'learning_rate': None,
    'n_epochs_warmup': None,  # number of epochs to warm up the lr; should be an int
    'momentum': None,
    'weight_decay': None,
    'wd_schedule': None,
    'decouple_wd': None,
    # --- Sparsifying strategy ---
    'strategy': None,
    'use_pretrained': None,
    'goal_sparsity': None,
    'pruning_selector': None,  # one of 'global', 'uniform', 'random', 'LAMP'
    'correct_batch_statistics': None,
    # --- Retraining ---
    'n_phases': None,  # should be 1 unless IMP is used
    'n_epochs_per_phase': None,
    'n_epochs_to_split': None,
    'retrain_schedule': None,
    'retrain_schedule_warmup': None,
    'retrain_schedule_init': None,
    'retrain_wd': None,
    'retrain_wd_schedule': None,
    'retrain_adaptive_in_every_cycle': None,
    # --- Penalize* ---
    'group_penalty': None,
    # --- SparseFW ---
    'lmo': None,
    'lmo_mode': None,
    'lmo_ord': None,
    'lmo_value': None,
    'lmo_k': None,
    'lmo_rescale': None,
    'lmo_global': None,
    'lmo_delay': None,
    'lmo_nuc_method': None,
    'lmo_tighten_interval': None,
    'lmo_loosen_eps': None,
    'lmo_tighten_eps': None,
    'lmo_tighten_redistribute': None,
    # --- CRAM ---
    'cram_k': None,
    'cram_rho': None,
    'cram_plus': None,
    # --- ABFP ---
    'abfp_k': None,
    # --- GL ---
    'lasso_tradeoff': None,
    # --- SFP ---
    'sfp_k': None,
    'sfp_start_epoch': None,
}

# Passing --debug on the command line replaces the unset defaults with a
# concrete configuration (MNIST + SimpleCNN, 4 epochs) so the pipeline can be
# run locally without a sweep.
if '--debug' in sys.argv:
    defaults.update({
        # --- System ---
        'run_id': 1,
        'computer': socket.gethostname(),
        'collect_class_statistics': False,
        # --- Setup ---
        'dataset': 'mnist',
        'arch': 'SimpleCNN',
        'n_epochs': 4,
        'batch_size': 1028,
        # --- Efficiency ---
        'use_amp': True,
        # --- Optimizer ---
        'optimizer': 'SFW',
        'learning_rate': '(Linear, 0.1, 0.0001)',
        'n_epochs_warmup': None,  # number of epochs to warm up the lr; should be an int
        'momentum': 0.9,
        'weight_decay': 0.0001,
        'wd_schedule': '(LastOnly, 0.4)',
        'decouple_wd': False,
        # --- Sparsifying strategy ---
        'strategy': 'SVDEnergyIteration',
        'use_pretrained': None,
        'goal_sparsity': 0.5,
        'pruning_selector': 'uniform',  # one of 'global', 'uniform', 'random', 'LAMP'
        'correct_batch_statistics': True,
        # --- Retraining ---
        'n_phases': 2,  # should be 1 unless IMP is used
        'n_epochs_per_phase': 1,
        'n_epochs_to_split': None,
        'retrain_schedule': 'ALLR',
        'retrain_schedule_warmup': None,
        'retrain_schedule_init': None,
        'retrain_wd': 0.0005,
        'retrain_wd_schedule': '(InitialOnly, 0.5)',
        'retrain_adaptive_in_every_cycle': True,
        # --- Penalize* ---
        'group_penalty': 0.005,
        # --- SparseFW ---
        'lmo': 'GroupKSupportNormBall',
        'lmo_mode': 'initialization',
        'lmo_ord': None,
        'lmo_value': 4,
        'lmo_k': 0.1,
        'lmo_rescale': 'gradient',
        'lmo_global': False,
        'lmo_delay': 0.25,
        'lmo_nuc_method': 'qrpartial',
        'lmo_tighten_interval': 1,
        'lmo_tighten_eps': 0.25,
        'lmo_loosen_eps': 0.2,
        'lmo_tighten_redistribute': True,
        # --- CRAM ---
        'cram_k': 0.7,
        'cram_rho': 0.001,
        'cram_plus': True,
        # --- ABFP ---
        'abfp_k': 0.3,
        # --- GL ---
        'lasso_tradeoff': 0.5,
        # --- SFP ---
        'sfp_k': 0.2,
        'sfp_start_epoch': 2,
    })

# Start the wandb run seeded with `defaults`; when launched from a sweep,
# wandb substitutes the project/entity and the swept hyperparameter values.
wandb.init(
    config=defaults,
    project='test-000',  # automatically changed in sweep
    entity=None,  # automatically changed in sweep
)
config = wandb.config

# Use the first CUDA device if any GPU is visible, otherwise fall back to CPU.
ngpus = torch.cuda.device_count()
config.update({'device': 'cuda:0' if ngpus > 0 else 'cpu'})

# At the moment, IMP is the only strategy that requires a pretrained model, all others start from scratch
if config.use_pretrained is not None:
    # A pretrained model was requested: hand off to the pretrainedRunner.
    runner = pretrainedRunner(config=config)
else:
    # Training from scratch: fail fast if the requested strategy is not defined
    # in scratchStrategies. config.strategy defaults to None, which getattr/hasattr
    # reject with TypeError, so validate the type explicitly instead of swallowing
    # every exception.
    if not isinstance(config.strategy, str) or not hasattr(scratchStrategies, config.strategy):
        raise NotImplementedError(
            f"Strategy {config.strategy!r} does not exist, "
            "potentially forgot to specify 'use_pretrained'."
        )
    runner = scratchRunner(config=config)
runner.run()

# Close the wandb run. The run directory must be read before finishing, since
# wandb.run is no longer the active run afterwards. wandb.finish() is the
# documented API; wandb.join() is only a legacy alias.
wandb_dir_path = wandb.run.dir
wandb.finish()

# Delete the local wandb files, then the runner's temporary directory.
# EAFP: a directory that is already gone is fine; any other error propagates.
for _cleanup_dir in (wandb_dir_path, runner.tmp_dir):
    try:
        shutil.rmtree(_cleanup_dir)
    except FileNotFoundError:
        pass
